import os
import sys
import cv2
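'''
Second-stage test: segment each candidate image, classify it, and copy the
images predicted as positive into the output directory.
'''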

from sklearn.ensemble import RandomForestClassifier
import pymeanshift as pms
import pickle
import multiprocessing as mp
import numpy as np
import functions as fc
import sys
import tensorflow as tf
from sklearn.decomposition import PCA
# TF-Slim alias (tf.contrib exists only in TensorFlow 1.x); runTest uses it to
# collect the VGG16 model variables that are restored from the checkpoint.
slim = tf.contrib.slim

def _parse_location(filename):
    '''
    Return the (v, u) position encoded in a file name such as "12_34.jpg".

    The extension is stripped, the stem is split on '_', and the first two
    fields are parsed as integers (v first, then u). Returns None when the
    name does not contain at least two integer fields.
    '''
    fields = os.path.splitext(filename)[0].split('_')
    try:
        return int(fields[0]), int(fields[1])
    except (IndexError, ValueError):
        return None


def runTest(clf,filenames,in_path,out_path, lock, kmeans=None, vgg=1, pca=None):
    '''
    Classify each candidate image in `filenames` and copy the positives to `out_path`.

    Each file name is expected to look like "<v>_<u>.<ext>". Images whose
    (u, v) position lies outside the window returned by fc.get_corner(True)
    are skipped. Every remaining image is segmented with fc.segmentation and
    turned into a feature vector in one of three modes:

    * vgg is not None: VGG16 features (fc.vgg16) reduced with `pca`;
    * kmeans is not None (and vgg is None): bag-of-words over SIFT descriptors;
    * both None: SIFT features stacked with fc.miscellaneousFeature.

    Images for which clf.predict returns 1 are written to out_path as
    "<v>_<u>.jpg". Files that cannot be parsed or read are skipped with a message.

    Args:
        clf: trained classifier exposing `predict`.
        filenames: image file names (relative to in_path) handled by this call.
        in_path: directory containing the input images.
        out_path: directory that receives the positive images.
        lock: multiprocessing lock held around every write to out_path.
        kmeans: fitted bag-of-words vocabulary, or None.
        vgg: any non-None value enables VGG features; None disables them.
        pca: fitted PCA applied to VGG features; required when vgg is not None.

    Raises:
        ValueError: if VGG features are enabled but no PCA model is given.
    '''
    if vgg is not None and pca is None:
        # Fail before loading the VGG model rather than on the first image.
        raise ValueError('pca is required when vgg features are enabled')

    u_min, u_max, v_min, v_max = fc.get_corner(True)

    sess = None
    try:
        if vgg is not None:
            print('using vgg.............')
            vgg_model = fc.vgg16()
            sess = tf.Session(config=tf.ConfigProto(log_device_placement=False, allow_soft_placement=True))
            sess.run(tf.global_variables_initializer())
            # Map each model variable to its name in the checkpoint, then restore.
            variables_to_restore = {vgg_model.name_in_checkpoint(var): var
                                    for var in slim.get_model_variables()}
            restorer = tf.train.Saver(variables_to_restore)
            model_path = '/media/deeplearning/DATA/system/wbw/tools/weights/vgg_16.ckpt'
            restorer.restore(sess, model_path)

        total = len(filenames)
        for count, filename in enumerate(filenames):
            print('processing %d/%d' % (count, total))

            location = _parse_location(filename)
            if location is None:
                print('skipping %s: expected a name like <v>_<u>.<ext>' % filename)
                continue
            v, u = location
            # Only images inside the window from fc.get_corner are tested;
            # check this before paying for the image read.
            if u > u_max or u < u_min or v > v_max or v < v_min:
                continue

            img0 = cv2.imread(os.path.join(in_path, filename))
            if img0 is None:
                # cv2.imread returns None instead of raising for unreadable files.
                print('skipping %s: could not read image' % filename)
                continue

            # Work on a copy; the untouched img0 is what gets written out.
            img = img0.copy()
            segmented_image, labels_image, _ = fc.segmentation(img)

            feature = None
            if kmeans is None and vgg is None:
                # SIFT features stacked with miscellaneous region features.
                feature1 = fc.miscellaneousFeature(segmented_image, labels_image)
                feature2 = fc.featureGenSift(segmented_image, labels_image)
                if feature2 is not None:
                    feature = np.hstack((feature2, feature1))
            elif vgg is not None:
                # VGG16 features of the segmented image, reduced with the fitted PCA.
                segmented_image = vgg_model.pre_process(segmented_image)
                feature = pca.transform(vgg_model.vgg_feature(sess, [segmented_image]))
            else:
                # Bag-of-words over SIFT descriptors collected from the image copy.
                sift = fc.collect_sift(img, labels_image)
                if sift is not None:
                    feature = fc.build_bow_feature(kmeans, sift)

            if feature is None:
                # The SIFT extractor returned None: nothing to classify.
                continue
            pred = clf.predict(feature)
            if pred[0] == 1:
                # Serialize writes across worker processes; `with` releases the
                # lock even if imwrite raises, so other workers cannot deadlock.
                with lock:
                    cv2.imwrite(os.path.join(out_path, str(v) + '_' + str(u) + '.jpg'), img0)
    finally:
        if sess is not None:
            sess.close()





def main():
    '''
    Run the second-stage test with the hard-coded configuration below.

    Selects the feature mode (VGG, bag-of-words, or SIFT + miscellaneous
    features) and the classifier, then loads the required pickled models.
    It empties the output directory and splits the images in `in_path` into
    contiguous chunks. Each chunk is processed by a separate worker process
    running `runTest`.
    '''
    in_path = '/media/deeplearning/DATA/system/wbw/GDA_project/MyProject/hog_color_pca_results/posStore'
    classifiername = 'rf'   # 'rf' or 'svm'

    # Feature mode.
    #   vgg:    any non-None value -> use VGG features; None -> do not use VGG
    #           (runTest branches on `vgg is not None`).
    #   kmeans: set to a truthy value to load the bag-of-words vocabulary.
    # At most one of the two may be enabled.
    vgg = 1
    kmeans = None
    if kmeans:
        # NOTE: pickle.load must only be used on trusted, locally produced files.
        with open('/media/deeplearning/DATA/system/wbw/GDA_project/MyProject/kmeans_clfs/kmeans(n=15).p','rb') as f:
            kmeans = pickle.load(f)
    else:
        # Normalize any falsy setting so the `is not None` checks below agree.
        kmeans = None
    if kmeans is not None and vgg is not None:
        # runTest would compute VGG features while a BoW classifier was loaded.
        sys.exit('kmeans and vgg features are mutually exclusive; set one of them to None.')

    # The PCA model is only used to reduce VGG features.
    pca = None
    if vgg is not None:
        with open('/media/deeplearning/DATA/system/wbw/GDA_project/MyProject/segmented_classify(change_segmentation)/pca_vgg.p','rb') as f:
            pca = pickle.load(f)

    print('choosing classifier: {}'.format(classifiername))
    print('vgg={},kmeans={}'.format(vgg,kmeans))
    print('====================================================================')
    input('press Enter to continue.')

    clf_file = None
    if classifiername == 'rf':
        out_path = '/media/deeplearning/DATA/system/wbw/GDA_project/MyProject/segmented_classify/RF/output/'
        if kmeans is not None:
            clf_file = '/media/deeplearning/DATA/system/wbw/GDA_project/MyProject/segmented_classify/RF/rf_bow.p'
        elif vgg is not None:
            clf_file = '/media/deeplearning/DATA/system/wbw/GDA_project/MyProject/segmented_classify(change_segmentation)/RF/rf_vgg.p'
        else:
            clf_file = '/media/deeplearning/DATA/system/wbw/GDA_project/MyProject/segmented_classify/RF/rf_central_sift.p'
    elif classifiername == 'svm':
        out_path = '/media/deeplearning/DATA/system/wbw/GDA_project/MyProject/segmented_classify/svm/output/'
        if kmeans is not None:
            clf_file = '/media/deeplearning/DATA/system/wbw/GDA_project/MyProject/segmented_classify/svm/svm_bow.p'
        elif vgg is not None:
            clf_file = '/media/deeplearning/DATA/system/wbw/GDA_project/MyProject/segmented_classify/svm/svm_vgg.p'
        # No SVM model is configured for the SIFT + miscellaneous features.
    else:
        sys.exit('classifier name not valid!!!!!!!!!!!')
    if clf_file is None:
        sys.exit('no {} classifier configured for the selected features'.format(classifiername))
    with open(clf_file, 'rb') as f:
        clf = pickle.load(f)

    print('clear old files.........')
    # Ensure the output directory exists; cv2.imwrite does not create it.
    os.makedirs(out_path, exist_ok=True)
    for file_name in os.listdir(out_path):
        file_path = os.path.join(out_path, file_name)
        if os.path.isfile(file_path):
            os.unlink(file_path)

    # One worker when VGG is enabled. Each runTest call builds its own VGG16
    # model and TensorFlow session, presumably too heavy to replicate (confirm).
    CPUnum = 1 if vgg is not None else 32
    filenames = os.listdir(in_path)
    print('total files:', len(filenames))

    # Contiguous chunks whose sizes differ by at most one; never start more
    # workers than there are files.
    n_workers = max(1, min(CPUnum, len(filenames)))
    base, extra = divmod(len(filenames), n_workers)
    lock = mp.Lock()
    procs = []
    start = 0
    for i in range(n_workers):
        end = start + base + (1 if i < extra else 0)
        proc = mp.Process(target=runTest,
                          args=(clf, filenames[start:end], in_path, out_path, lock, kmeans, vgg, pca))
        procs.append(proc)
        proc.start()
        start = end
    for proc in procs:
        proc.join()


# Script entry point. The guard also stops main() from re-running when
# multiprocessing re-imports this module in a child process (spawn/forkserver).
if __name__=='__main__':
    main()
